Hierarchically coupled causal models

This notebook showcases how to generate hierarchically coupled causal models.

# Autoreload extension
%load_ext autoreload
%autoreload 2

%matplotlib inline
import torch
import xarray as xr
import matplotlib.pyplot as plt
import matplotlib as mpl

from causaldynamics.utils import set_rng_seed
from causaldynamics.initialization import initialize_weights, initialize_biases, initialize_x, initialize_system_and_driver
from causaldynamics.scm import create_scm_graph, get_root_nodes_mask, GrowingNetworkWithRedirection
from causaldynamics.mlp import propagate_mlp
from causaldynamics.systems import solve_system, solve_random_systems
from causaldynamics.plot import plot_trajectories, animate_3d_trajectories, plot_scm
# Set a fixed seed for reproducibility
set_rng_seed(42)

# Increase the animation limit to 50MB
mpl.rcParams['animation.embed_limit'] = 50 * 1024**2 

Generate data the convenient way

This is the recommended, convenient way of generating data. If you need more flexibility, see the step-by-step guide below.

from causaldynamics.creator import create_scm, simulate_system, create_plots

num_nodes = 2
node_dim = 3
num_timesteps = 200

confounders = False
system_name = 'Lorenz'

A, W, b, root_nodes, magnitudes = create_scm(num_nodes, 
                                            node_dim,
                                            confounders=confounders)

data = simulate_system(A, W, b, 
                      num_timesteps=num_timesteps, 
                      num_nodes=num_nodes,
                      system_name=system_name) 

create_plots(
            data,
            A,
            root_nodes=root_nodes,
            out_dir='.',
            show_plot=True,
            save_plot=False,
            return_html_anim=True,
            create_animation=False,
        )
INFO - Creating SCM with 2 nodes and 3 dimensions each...
INFO - Simulating Lorenz system for 200 timesteps...
INFO - Generating visualizations
[Figures: SCM graph; node trajectories]
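A quick check of why a parameter line such as `confounders=False,` (with a trailing comma) is a bug. In Python, the trailing comma turns the value into the one-element tuple `(False,)`. A non-empty tuple is truthy, so any `if confounders:` test would behave as if confounders were enabled. Whether create_scm performs such a truthiness test is an assumption. The snippet only demonstrates the Python semantics.

```python
# A trailing comma creates a one-element tuple, not a boolean.
with_comma = False,
without_comma = False

print(type(with_comma).__name__, bool(with_comma))        # tuple True
print(type(without_comma).__name__, bool(without_comma))  # bool False
```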

Generate data step by step

Let’s explore step-by-step what happens internally in the create_scm and simulate_system functions.

Let’s create a minimal example with a causal graph that has two nodes, 0 and 1, and a single edge 1->0. The signal we send through this graph is a three-dimensional Lorenz attractor trajectory computed for 1000 time steps.

# Define parameters for a small test example
N_nodes = 2
N_timesteps = 1000
N_dimensions = 3

# Generate Lorenz attractor trajectory
d_lorenz = solve_system(N_timesteps, N_nodes, "Lorenz")
# Sample the simplest possible adjacency matrix: 0<--1
A = GrowingNetworkWithRedirection(N_nodes).generate()
A
tensor([[0., 0.],
        [1., 0.]])
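The adjacency matrix convention can be inferred from the examples in this notebook: an entry A[i, j] = 1 appears to encode the edge i->j. For instance, the collider 1->0<-2 in the section on specifying graph structures has ones at A[1, 0] and A[2, 0]. Under that reading, a root node is a node whose column contains only zeros. The hypothetical helpers below sketch this in plain Python. get_root_nodes_mask may be implemented differently.

```python
def edges_from_adjacency(A):
    """List edges (i, j) meaning i -> j, assuming A[i][j] != 0 encodes i -> j."""
    n = len(A)
    return [(i, j) for i in range(n) for j in range(n) if A[i][j]]


def root_nodes_mask(A):
    """A node is a root if it has no incoming edge, i.e. its column is all zeros."""
    n = len(A)
    return [not any(A[i][j] for i in range(n)) for j in range(n)]


A_example = [[0.0, 0.0],
             [1.0, 0.0]]  # same matrix as sampled above
print(edges_from_adjacency(A_example))  # [(1, 0)]
print(root_nodes_mask(A_example))       # [False, True]
```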
# Sample random weights for the MLPs
W = initialize_weights(N_nodes, N_dimensions)
W
tensor([[[ 1.8306,  0.1129,  0.4252],
         [-1.3446,  1.6114,  0.5914],
         [-1.1415, -0.5930, -0.0058]],

        [[ 1.4170, -2.0788,  0.6370],
         [ 1.3824, -0.9156,  0.0159],
         [ 0.7230, -0.3092,  0.8614]]])
# Sample random biases for the MLPs
b = initialize_biases(N_nodes, N_dimensions)
b
tensor([[ 1.2865, -1.7983,  0.3016],
        [ 2.6508, -1.3340,  0.6698]])

Let’s propagate the Lorenz attractor through the minimal test graph 1->0 and plot the result.

First, initialize the root node 1 with the time-series data from the Lorenz attractor. Then, propagate this signal through the edge 1->0. The signal is the input to the sampled multilayer perceptron (MLP), which has no activation function and is therefore an affine map. The output is the state of node 0.
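A minimal plain-Python sketch of this propagation for a single time step. It makes several hypothetical assumptions. Each non-root node j sums the states of its parents and then applies its own weight matrix and bias as W[j] @ x + b[j]. propagate_mlp may aggregate parents, order the matrix product, or handle the time axis differently. Nodes are visited in topological order (Kahn's algorithm), so every parent is computed before its children. The function and variable names are illustrative only.

```python
def topological_order(A):
    """Kahn's algorithm; assumes A[i][j] != 0 encodes the edge i -> j."""
    n = len(A)
    indegree = [sum(1 for i in range(n) if A[i][j]) for j in range(n)]
    ready = [j for j in range(n) if indegree[j] == 0]  # root nodes first
    order = []
    while ready:
        i = ready.pop(0)
        order.append(i)
        for j in range(n):
            if A[i][j]:
                indegree[j] -= 1
                if indegree[j] == 0:
                    ready.append(j)
    if len(order) != n:
        raise ValueError("adjacency matrix contains a cycle")
    return order


def propagate_affine(A, W, b, root_states):
    """Hypothetical single-time-step propagation through activation-free (affine) node maps.

    Each non-root node j sums its parents' states and applies W[j] @ summed + b[j].
    """
    n = len(A)
    x = [None] * n
    for j in topological_order(A):
        parents = [i for i in range(n) if A[i][j]]
        if not parents:
            x[j] = list(root_states[j])  # root nodes keep their driving signal
            continue
        dim = len(b[j])
        summed = [sum(x[i][d] for i in parents) for d in range(dim)]
        x[j] = [sum(W[j][r][c] * summed[c] for c in range(dim)) + b[j][r] for r in range(dim)]
    return x


A_toy = [[0, 0], [1, 0]]                      # edge 1 -> 0
W_toy = [[[2, 0, 0], [0, 2, 0], [0, 0, 2]],   # weights of node 0
         [[1, 0, 0], [0, 1, 0], [0, 0, 1]]]   # weights of root node 1 (unused)
b_toy = [[1, 0, -1], [0, 0, 0]]
root_states_toy = [None, [1.0, 2.0, 3.0]]     # state of root node 1 at one time step
print(propagate_affine(A_toy, W_toy, b_toy, root_states_toy))
# [[3.0, 4.0, 5.0], [1.0, 2.0, 3.0]]
```

Because there is no activation, stacking several such layers would still give an affine map of the parent states.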

# Initialize the root nodes with the Lorenz attractor
init = initialize_x(d_lorenz, A)

# Propagate Lorenz trajectory through the SCM
x = propagate_mlp(A, W, b, init=init)

# Visualize the results using xarray
da = xr.DataArray(x, dims=['time', 'node', "dim"])
root_nodes = get_root_nodes_mask(A)

plot_scm(G=create_scm_graph(A), root_nodes=root_nodes)
plt.show()

plot_trajectories(da, root_nodes=root_nodes, sharey=False)
plt.show()
[Figures: SCM graph; node trajectories]
# Create an animation of the simple coupled Lorenz attractor. 
da_sel = da.isel(time=slice(0, 500))
anim = animate_3d_trajectories(da_sel, frame_skip=2, rotate=True, show_history=True, plot_type='subplots', root_nodes=root_nodes)
display(anim)

For convenience, the simulate_system function used at the top of this notebook performs these steps in a single call.

Stochastic dynamical systems

Now we drive the root nodes with a randomly selected dynamical system and add noise to its trajectory (noise=0.5).
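The `noise` entry in make_trajectory_kwargs is presumably forwarded to the trajectory generator. This is an assumption; we believe dysts interprets it as the strength of a stochastic, Brownian-motion-like forcing during integration. The sketch below illustrates what additive noise means for the Lorenz system using the Euler–Maruyama scheme. The step size, Lorenz parameters, and function names are illustrative choices, not those of causaldynamics or dysts.

```python
import math
import random


def lorenz_drift(state, sigma=10.0, rho=28.0, beta=8.0 / 3.0):
    """Right-hand side of the Lorenz equations."""
    x, y, z = state
    return [sigma * (y - x), x * (rho - z) - y, x * y - beta * z]


def euler_maruyama_lorenz(x0, n_steps, dt=0.01, noise=0.5, seed=0):
    """Integrate dX = f(X) dt + noise * dB with the Euler-Maruyama scheme."""
    rng = random.Random(seed)
    sqrt_dt = math.sqrt(dt)
    state = list(x0)
    trajectory = [state]
    for _ in range(n_steps - 1):
        drift = lorenz_drift(state)
        # Deterministic Euler step plus a Gaussian increment with variance noise**2 * dt
        state = [s + f * dt + noise * sqrt_dt * rng.gauss(0.0, 1.0)
                 for s, f in zip(state, drift)]
        trajectory.append(state)
    return trajectory


clean = euler_maruyama_lorenz([1.0, 1.0, 1.0], n_steps=1000, noise=0.0)
noisy = euler_maruyama_lorenz([1.0, 1.0, 1.0], n_steps=1000, noise=0.5)
print(len(noisy), noisy[-1])
```

With noise=0.0 the scheme reduces to the ordinary forward Euler method, which makes it easy to compare the deterministic and stochastic trajectories.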

# Set parameters
num_nodes = 2
num_timesteps = 1000
dimensions = 3

# Sample the SCM and all hyperparameters, and propagate a noisy, randomly selected dynamical system through the SCM
A = GrowingNetworkWithRedirection(num_nodes).generate()
root_nodes = get_root_nodes_mask(A)
W = initialize_weights(num_nodes, dimensions)
b = initialize_biases(num_nodes, dimensions)
sol = solve_random_systems(num_timesteps, num_nodes, make_trajectory_kwargs={'noise': 0.5})
init = initialize_x(sol, A)

x = propagate_mlp(A, W, b, init=init)
da = xr.DataArray(x, dims=['time', 'node', "dim"])
# Plot the SCM graph
plot_scm(G=create_scm_graph(A), root_nodes=root_nodes)
plt.show()

# Plot the trajectories
plot_trajectories(da, root_nodes, sharey=False)
plt.show()

# # Animate the trajectories. This takes a while to run...
# anim = animate_3d_trajectories(da, frame_skip=5, rotation_speed=0.2, rotate=True, show_history=True, plot_type='subplots', root_nodes=root_nodes)
# display(anim)
[Figures: SCM graph; node trajectories]

Internal standardization (iSCM vs. SCM)

Let’s look at the difference between internally standardized SCMs (iSCMs) and non-standardized SCMs. The iSCM version is obtained by passing standardize=True to propagate_mlp.
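As we understand it, internal standardization rescales each node to zero mean and unit variance (over time) right after it is computed and before its children use it. This keeps the signal scale from growing along the causal order. Whether standardize=True in propagate_mlp does exactly this (for example, per dimension and including root nodes) is an assumption. The plain-Python sketch below contrasts a chain in which every edge multiplies the signal by 3, with and without standardization. It uses a one-dimensional series for brevity; a 3-dimensional node would be standardized per dimension.

```python
import math
import statistics


def standardize(series):
    """Shift to zero mean and scale to unit (population) standard deviation."""
    mean = statistics.fmean(series)
    std = statistics.pstdev(series)
    return [(v - mean) / std for v in series]


def simulate_chain(root, gain, depth, internal_standardization):
    """Chain root -> 1 -> ... -> depth where each edge multiplies the signal by `gain`."""
    nodes = [standardize(root) if internal_standardization else list(root)]
    for _ in range(depth):
        child = [gain * v for v in nodes[-1]]
        nodes.append(standardize(child) if internal_standardization else child)
    return nodes


root = [math.sin(0.1 * t) for t in range(200)]
scm = simulate_chain(root, gain=3.0, depth=4, internal_standardization=False)
iscm = simulate_chain(root, gain=3.0, depth=4, internal_standardization=True)

print([round(statistics.pstdev(n), 3) for n in scm])   # grows by a factor of 3 per edge
print([round(statistics.pstdev(n), 3) for n in iscm])  # stays at 1.0
```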

# Set parameters
num_nodes = 2
num_timesteps = 1000
dimensions = 3

# Sample the SCM, all hyperparameters and propagate the Lorenz attractor through the SCM
A = GrowingNetworkWithRedirection(num_nodes).generate()
root_nodes = get_root_nodes_mask(A)
W = initialize_weights(num_nodes, dimensions)
b = initialize_biases(num_nodes, dimensions)
d_lorenz = solve_system(num_timesteps, num_nodes, "Lorenz")
init = initialize_x(d_lorenz, A)

x = propagate_mlp(A, W, b, init=init)
da = xr.DataArray(x, dims=['time', 'node', "dim"])

x_s = propagate_mlp(A, W, b, init=init, standardize=True)
da_s = xr.DataArray(x_s, dims=['time', 'node', "dim"])
# Plot the SCM graph
plot_scm(G=create_scm_graph(A), root_nodes=root_nodes)
plt.show()

# Plot the trajectories
plot_trajectories(da, root_nodes, sharey=False)
da_s_sel = da_s.isel(time=slice(None, None, 10))
plot_trajectories(da_s_sel, root_nodes, sharey=False)
plt.show()

# # Animate the trajectories. This takes a while to run...
# anim = animate_3d_trajectories(da, frame_skip=5, rotation_speed=0.2, rotate=True, show_history=True, plot_type='subplots', root_nodes=root_nodes)
# display(anim)
[Figures: SCM graph; trajectories without standardization (SCM); trajectories with internal standardization (iSCM), every 10th time step]

Larger dynamical system

Let’s look at an example with randomly selected dynamical systems and a larger SCM with 10 nodes.
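The graphs in this notebook are sampled with GrowingNetworkWithRedirection. The following hypothetical plain-Python sketch illustrates the general growing-network-with-redirection idea:

1. Nodes are added one at a time.
2. Each new node picks a uniformly random existing node as its target.
3. With probability `redirect_prob`, the link is redirected to that node's own target.

Following the convention seen in this notebook's examples, the new node points to the older one (A[new][target] = 1). The result is a DAG in which node 0 has no outgoing edge. The redirection probability and the one-edge-per-node rule are assumptions; the library's implementation may differ.

```python
import random


def growing_network_with_redirection(n, redirect_prob=0.5, seed=None):
    """Hypothetical GNR-style sampler returning an n x n adjacency matrix (A[i][j] = 1 means i -> j)."""
    rng = random.Random(seed)
    A = [[0] * n for _ in range(n)]
    target_of = [None] * n  # node 0 is the first node and links to nothing
    for new in range(1, n):
        target = rng.randrange(new)  # uniformly random existing node
        if target_of[target] is not None and rng.random() < redirect_prob:
            target = target_of[target]  # redirect to the target's own target
        A[new][target] = 1
        target_of[new] = target
    return A


A_gnr = growing_network_with_redirection(10, seed=0)
for row in A_gnr:
    print(row)
```

With n=2 this sketch can only produce [[0, 0], [1, 0]], which matches the matrix sampled in the step-by-step example.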

# This takes a while to run...
num_nodes = 10
num_timesteps = 500
dimensions = 3

# Sample the SCM and all hyperparameters, and propagate randomly selected dynamical systems through the SCM
A = GrowingNetworkWithRedirection(num_nodes).generate()
W = initialize_weights(num_nodes, dimensions)
b = initialize_biases(num_nodes, dimensions)
sol = solve_random_systems(num_timesteps, num_nodes, make_trajectory_kwargs={})
init = initialize_x(sol, A)
x = propagate_mlp(A, W, b, init=init)
root_nodes = get_root_nodes_mask(A)
da = xr.DataArray(x, dims=['time', 'node', "dim"])
UserWarning (dysts/base.py): This system has at least one unbounded variable, which has been mapped to a bounded domain. Pass argument postprocess=False in order to generate trajectories from the raw system.
UserWarning (dysts/base.py): SprottJerk: Integration did not complete for initial condition [-0.45389498  0.49816488  0.37857162], only got 125 points. Skipping this point
# Plot the SCM graph
plot_scm(G=create_scm_graph(A), root_nodes=root_nodes)
plt.show()

# Plot the trajectories
plot_trajectories(da, root_nodes, sharey=False)
plt.show()

# # Animate the trajectories
# anim = animate_3d_trajectories(da, frame_skip=2, rotate=True, show_history=True, plot_type='subplots', root_nodes=root_nodes)
# display(anim)
[Figures: SCM graph; node trajectories]

Add periodic and linear drivers

# This takes a while to run...
num_nodes = 10
num_timesteps = 500
dimensions = 3
init_ratios = [1, 1, 1] # Set ratios of dynamical systems, periodic and linear drivers at root nodes. Here: equal ratio.

# Sample the SCM and all hyperparameters, and propagate a mix of dynamical systems and periodic and linear drivers through the SCM
A = GrowingNetworkWithRedirection(num_nodes).generate()
W = initialize_weights(num_nodes, dimensions)
b = initialize_biases(num_nodes, dimensions)
sol = initialize_system_and_driver(num_timesteps, num_nodes, init_ratios=init_ratios, system_name='random', node_dim=dimensions, time_lag=0, device=None, make_trajectory_kwargs={})
init = initialize_x(sol, A)
x = propagate_mlp(A, W, b, init=init)
root_nodes = get_root_nodes_mask(A)
da = xr.DataArray(x, dims=['time', 'node', "dim"])
UserWarning (dysts/base.py): SprottJerk: Integration did not complete for initial condition [-0.03610093  0.03962197  0.03011002], only got 374 points. Skipping this point
UserWarning (dysts/base.py): This system has at least one unbounded variable, which has been mapped to a bounded domain. Pass argument postprocess=False in order to generate trajectories from the raw system.
# Plot the SCM graph
plot_scm(G=create_scm_graph(A), root_nodes=root_nodes)
plt.show()

# Plot the trajectories
plot_trajectories(da, root_nodes, sharey=False)
plt.show()

# # Animate the trajectories
# anim = animate_3d_trajectories(da, frame_skip=2, rotate=True, show_history=True, plot_type='subplots', root_nodes=root_nodes)
# display(anim)
[Figures: SCM graph; node trajectories]
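The init_ratios argument sets the relative proportions of dynamical systems, periodic drivers, and linear drivers among the root nodes. Below is a hypothetical plain-Python sketch of such an assignment. Each root node draws its driver type with probability proportional to init_ratios. The periodic and linear drivers are modeled as a sine wave and a straight line. The sampling scheme, signal forms, parameter values, and helper names are assumptions for illustration, not what initialize_system_and_driver actually does; it may, for example, allocate counts deterministically rather than sample them.

```python
import math
import random


def assign_driver_types(num_roots, init_ratios, seed=None):
    """Draw a driver type for each root node with probabilities proportional to init_ratios."""
    rng = random.Random(seed)
    kinds = ["system", "periodic", "linear"]
    return rng.choices(kinds, weights=init_ratios, k=num_roots)


def periodic_driver(num_timesteps, amplitude=1.0, period=50.0, phase=0.0):
    """Sine-wave driver."""
    return [amplitude * math.sin(2 * math.pi * t / period + phase) for t in range(num_timesteps)]


def linear_driver(num_timesteps, slope=0.01, intercept=0.0):
    """Straight-line (trend) driver."""
    return [intercept + slope * t for t in range(num_timesteps)]


print(assign_driver_types(5, [1, 1, 1], seed=0))
print(periodic_driver(5, period=4.0))
print(linear_driver(5, slope=0.5))  # [0.0, 0.5, 1.0, 1.5, 2.0]
```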

Specify graph structures

Let’s return to a small system. The SCM graph is the collider 1->0<-2.

# Set parameters
num_nodes = 3
num_timesteps = 200
dimensions = 3

# Specify the SCM graph, sample the weights and biases, and drive the root nodes with Lorenz and Rossler trajectories
A = torch.tensor([[0., 0., 0.],
        [1., 0., 0.],
        [1., 0., 0.]])
W = initialize_weights(num_nodes, dimensions)
b = initialize_biases(num_nodes, dimensions)
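init = solve_system(num_timesteps, num_nodes, "Lorenz")
init2 = solve_system(num_timesteps, num_nodes, "Rossler")
# Replace the signal of node 1 with the Rossler trajectory, so the root nodes 1 and 2 are driven by different systems
init[:, 1] = init2[:, 1]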
init = initialize_x(init, A)
x = propagate_mlp(A, W, b, init=init)
root_nodes = get_root_nodes_mask(A)
da = xr.DataArray(x, dims=['time', 'node', "dim"])
# Plot the SCM graph
plot_scm(G=create_scm_graph(A), root_nodes=root_nodes)
plt.show()

# Plot the trajectories
plot_trajectories(da, root_nodes, sharey=False)
plt.show()

# # Animate the trajectories
# anim = animate_3d_trajectories(da, frame_skip=2, rotate=True, show_history=True, plot_type='subplots', root_nodes=root_nodes)
# display(anim)
[Figures: SCM graph; node trajectories]

Now, we look at a linear chain: 2->1->0
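Instead of typing the matrices by hand, both graphs in this section can be built from an edge list. The helper below is hypothetical. It uses the same convention as the matrices here, where A[i][j] = 1 encodes the edge i->j. It returns nested lists, which can be wrapped with torch.tensor(...) if desired.

```python
def adjacency_from_edges(num_nodes, edges):
    """Build a num_nodes x num_nodes adjacency matrix with A[i][j] = 1.0 for each edge i -> j."""
    A = [[0.0] * num_nodes for _ in range(num_nodes)]
    for src, dst in edges:
        if src == dst:
            raise ValueError(f"self-loop on node {src} is not allowed in a DAG")
        A[src][dst] = 1.0
    return A


collider = adjacency_from_edges(3, [(1, 0), (2, 0)])  # 1 -> 0 <- 2
chain = adjacency_from_edges(3, [(2, 1), (1, 0)])     # 2 -> 1 -> 0
print(collider)  # [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 0.0, 0.0]]
print(chain)     # [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
```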

# Specify the SCM graph, sample the weights and biases, and propagate the Lorenz attractor through the SCM
A = torch.tensor([[0., 0., 0.],
        [1., 0., 0.],
        [0., 1., 0.]])
W = initialize_weights(num_nodes, dimensions)
b = initialize_biases(num_nodes, dimensions)
init = solve_system(num_timesteps, num_nodes, "Lorenz")
init = initialize_x(init, A)
x = propagate_mlp(A, W, b, init=init)
root_nodes = get_root_nodes_mask(A)
da = xr.DataArray(x, dims=['time', 'node', "dim"])
# Plot the SCM graph
plot_scm(G=create_scm_graph(A), root_nodes=root_nodes)
plt.show()

# Plot the trajectories
plot_trajectories(da, root_nodes, sharey=False)
plt.show()

# # Animate the trajectories
# anim = animate_3d_trajectories(da, frame_skip=2, rotate=True, show_history=True, plot_type='subplots', root_nodes=root_nodes)
# display(anim)
[Figures: SCM graph; node trajectories]

Loading precomputed chaotic systems

import os

import torch
import xarray as xr

data_dir = "output/chaotic_systems_dim3/20250410_200501-chaosys-dim3-N100_T2000/data"  # Change this to the directory containing the precomputed chaotic systems
files = ['Lorenz_N5_T100.nc', 'Rossler_N5_T100.nc']  # Change to the files you want to load

# Load each system's time series, stack them along a new 'systems' dimension,
# and permute so that 'systems' becomes the second (node) axis
paths = [os.path.join(data_dir, filename) for filename in files]
ds = [xr.load_dataset(path)['time_series'] for path in paths]
sel = torch.tensor(xr.concat(ds, dim='systems').values).permute(1, 0, 2)
init = initialize_x(sel, A)  # This can then be used to initialize the SCM
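The concat-and-permute step can be illustrated without xarray or torch. Assume each loaded time_series is a 2-D array of shape (time, dim), which the three-axis permute(1, 0, 2) suggests. Concatenating along a new 'systems' dimension then gives (systems, time, dim). The permute turns this into (time, systems, dim), so each system presumably occupies one node slot of the (time, node, dim) layout used throughout this notebook. A hypothetical plain-Python equivalent:

```python
def stack_systems(trajectories):
    """Turn a list of (time, dim) trajectories into a (time, systems, dim) nested list."""
    num_timesteps = len(trajectories[0])
    if any(len(traj) != num_timesteps for traj in trajectories):
        raise ValueError("all systems must have the same number of time steps")
    return [[list(traj[t]) for traj in trajectories] for t in range(num_timesteps)]


lorenz_like = [[0.0, 0.1, 0.2], [1.0, 1.1, 1.2]]    # 2 time steps, 3 dims
rossler_like = [[5.0, 5.1, 5.2], [6.0, 6.1, 6.2]]
stacked = stack_systems([lorenz_like, rossler_like])
print(stacked[1])  # [[1.0, 1.1, 1.2], [6.0, 6.1, 6.2]]
```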